{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --quiet poetry  # Fixes https://github.com/python-poetry/poetry/issues/532\n",
    "!pip install --quiet git+https://github.com/oughtinc/ergo.git@dccf8cd526a0e7367aa1fef086f042d7eaa53aa3\n",
    "!pip install --quiet plotnine"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
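    "Installing the packages may print a dependency-conflict message like the one below (left over from the pip output). `chainer` is not imported anywhere in this notebook.\n",
    "\n",
    "    ERROR: chainer 6.5.0 has requirement typing-extensions<=3.6.6, but you'll have typing-extensions 3.7.4.2 which is incompatible.\n",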
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import torch\n",
    "import io\n",
    "import zipfile\n",
    "import os\n",
    "import ergo\n",
    "import requests\n",
    "import scipy.stats\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from datetime import date, datetime, timedelta\n",
    "from types import SimpleNamespace\n",
    "from typing import List\n",
    "import plotnine\n",
    "from plotnine import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "    /usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.\n",
    "      import pandas.util.testing as tm\n",
    "\n",
    "Log into a Metaculus account\n",
    "\n",
    "If running in a Colab notebook, please enter your Metaculus credentials\n",
    "here.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_local():\n",
    "    \"\"\"Return True when the USER environment variable is set.\n",
    "\n",
    "    Used as this notebook's signal for a local run; presumably hosted\n",
    "    runtimes such as Colab leave USER unset (TODO confirm).\n",
    "    \"\"\"\n",
    "    # Plain environment lookup instead of the `%env USER` magic wrapped in a bare except.\n",
    "    return \"USER\" in os.environ\n",
    "\n",
    "metaculus_api = \"pandemic\"\n",
    "\n",
    "if is_local():\n",
    "    # Local run: credentials come from a .env file / the process environment.\n",
    "    from dotenv import load_dotenv\n",
    "    load_dotenv()\n",
    "    metaculus = ergo.Metaculus(username=os.getenv(\"METACULUS_USERNAME\"), password=os.getenv(\"METACULUS_PASSWORD\"), api_domain=metaculus_api)\n",
    "else:\n",
    "    # Hosted run: ask for credentials instead of storing a password in the notebook source.\n",
    "    from getpass import getpass\n",
    "    try:\n",
    "        metaculus = ergo.Metaculus(\n",
    "            username=os.getenv(\"METACULUS_USERNAME\") or input(\"Metaculus username: \"),\n",
    "            password=os.getenv(\"METACULUS_PASSWORD\") or getpass(\"Metaculus password: \"),\n",
    "            api_domain=metaculus_api,\n",
    "        )\n",
    "    except Exception as err:\n",
    "        # Best effort, as before: warn rather than abort, but show what went wrong.\n",
    "        print(f'WARNING, You will need to enter your metaculus credentials in this cell')\n",
    "        print(err)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Questions\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here are the questions we want to forecast:\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (Metaculus question id, region name) pairs, kept together so they cannot drift apart.\n",
    "question_specs = [\n",
    "    (3935, \"The United Kingdom\"),\n",
    "    (3948, \"France\"),\n",
    "    (3941, \"Poland\"),\n",
    "    (3939, \"The State of California\"),\n",
    "    (3937, \"Italy\"),\n",
    "]\n",
    "question_ids = [question_id for question_id, _region in question_specs]\n",
    "question_names = [region for _question_id, region in question_specs]\n",
    "\n",
    "areas = question_names\n",
    "questions = [\n",
    "    metaculus.get_question(question_id, name=region)\n",
    "    for question_id, region in question_specs\n",
    "]\n",
    "ergo.MetaculusQuestion.to_dataframe(questions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Assumptions and Question Information\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "-   intensity&#x2014;a score from 1 to 5 for the severity of, and degree of adherence to,\n",
    "    social distancing, with 5 being the most strictly\n",
    "    observed/enforced social distancing\n",
    "-   re-evaluation&#x2014;date on which the government said it would re-assess\n",
    "    the lockdown\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Area = dict\n",
    "assumptions = SimpleNamespace()\n",
    "\n",
    "assumptions.lockdowns = {\n",
    "  \"New York\" : {\n",
    "    \"data_key\" : \"New York\",\n",
    "    \"question_name\" : \"The State of New York\",\n",
    "    \"start\": date(2020, 3, 21),\n",
    "    \"intensity\": 3,\n",
    "    \"end\": None #pendulum.Date(2020, X, X)\n",
    "  },\n",
    "  \"California\" : {\n",
    "    \"data_key\" : \"California\",\n",
    "    \"question_name\" : \"The State of California\", \n",
    "    \"start\": date(2020, 3, 22),\n",
    "    \"end\": None #date(2020, X, X)\n",
    "  },\n",
    "  \"United Kingdom\": {\n",
    "    \"data_key\" : \"United Kingdom\",\n",
    "    \"question_name\" : \"The United Kingdom\", \n",
    "    \"start\": date(2020, 3, 23),\n",
    "    \"re_evaluated\": None, \n",
    "    \"intensity\": 3, \n",
    "    \"end\": None #date(2020, X, X)\n",
    "  },\n",
    "  \"France\": {\n",
    "    \"data_key\" : \"France\",\n",
    "    \"question_name\" : \"France\", \n",
    "    \"start\": date(2020, 3, 17),\n",
    "    \"re_evaluated\": date(2020, 4, 1),\n",
    "    \"intensity\": 3, \n",
    "    \"end\": None #date(2020, X, X)\n",
    "  },\n",
    "  \"Poland\": {\n",
    "    \"data_key\": \"Poland\",\n",
    "    \"question_name\" : \"Poland\",\n",
    "    \"start\": date(2020, 3, 25),\n",
    "    \"re_evaluated\": date(2020, 4, 11),\n",
    "    \"end\": None #date(2020, X, X)\n",
    "  },\n",
    "  \"Italy\": {\n",
    "    \"data_key\": \"Italy\",\n",
    "    \"question_name\" : \"Italy\",\n",
    "    \"start\": date(2020, 3, 23),\n",
    "    \"end\": None #date(2020, X, X)\n",
    "  }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Helper Functions\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hack implementation of date.fromisoformat (is in datetime @ pyton 3.7)\n",
    "def fromisoformat( xdate): \n",
    "    return datetime.strptime(xdate, '%Y-%m-%d').date()\n",
    "\n",
    "# This allows access to the right assumptions using a data key or a question.\n",
    "def get_assumptions(key, assumptions = assumptions.lockdowns):\n",
    "    \"\"\"Look up the lockdown assumptions for one region.\n",
    "\n",
    "    key: either a data key (str) that indexes `assumptions` directly, or a\n",
    "        MetaculusQuestion whose `name` equals an entry's \"question_name\".\n",
    "    assumptions: mapping of data key -> assumptions dict.\n",
    "\n",
    "    Returns the matching dict. When nothing matches, or `key` is neither a\n",
    "    str nor a question, prints a message and returns None.\n",
    "    \"\"\"\n",
    "    if isinstance(key, str):\n",
    "        if key in assumptions:\n",
    "            return assumptions[key]\n",
    "        print(f\"No assumptions for data key: {key}\")\n",
    "    elif isinstance(key, ergo.metaculus.MetaculusQuestion):\n",
    "        for entry in assumptions.values():\n",
    "            if entry['question_name'] == key.name:\n",
    "                return entry\n",
    "        # Report the question that was passed in rather than any notebook-level variable.\n",
    "        print(f\"No assumptions for question: {key.name}\")\n",
    "    else:\n",
    "        print(f\"Neither a question nor a data_key was passed\")\n",
    "    return None\n",
    "\n",
    "# adapted from https://techoverflow.net/2018/01/16/downloading-reading-a-zip-file-in-memory-using-python/\n",
    "def download_extract_zip(url):\n",
    "    \"\"\"\n",
    "    Download a ZIP file and extract its contents in memory\n",
    "    yields (filename, file-like object) pairs\n",
    "\n",
    "    Each file object is closed as soon as the generator is advanced, so read\n",
    "    it before asking for the next pair.\n",
    "\n",
    "    Raises requests.exceptions.HTTPError when the server answers with an\n",
    "    error status.\n",
    "    \"\"\"\n",
    "    response = requests.get(url)\n",
    "    # Fail with the HTTP error itself instead of handing an error page to ZipFile.\n",
    "    response.raise_for_status()\n",
    "    with zipfile.ZipFile(io.BytesIO(response.content)) as thezip:\n",
    "        for zipinfo in thezip.infolist():\n",
    "            with thezip.open(zipinfo) as thefile:\n",
    "                yield zipinfo.filename, thefile"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now have predictions from\n",
    "[http://www.healthdata.org/covid/data-downloads](http://www.healthdata.org/covid/data-downloads)\n",
    "\n",
    "***TODO*** if something like this is going to persist, then consider\n",
    "migrating to ergo/data/covid19.py\n",
    "\n",
    "Get Data\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url=\"https://ihmecovid19storage.blob.core.windows.net/latest/ihme-covid19.zip\"\n",
    "hospitalization_csv = \"Hospitalization_all_locs.csv\"\n",
    "\n",
    "# Start from None so a missing file is detected instead of leaving the name\n",
    "# undefined (or silently reusing a frame from an earlier run of this cell).\n",
    "infections_df = None\n",
    "for name, xfile in download_extract_zip(url):\n",
    "    if os.path.basename(name) == hospitalization_csv:\n",
    "        infections_df = pd.read_csv(xfile)\n",
    "\n",
    "if infections_df is None:\n",
    "    raise FileNotFoundError(f\"{hospitalization_csv} was not found in the archive at {url}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Clean and Enhance Data\n",
    "\n",
    "**variable description**\n",
    "\n",
    "-   admis_mean = daily # of admissions to hospital\n",
    "-   allbed_mean = cumsum of admis_mean\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the modelled regions and the columns used below.\n",
    "# .copy() gives an independent frame for the column assignments that follow.\n",
    "infections_df = infections_df.loc[infections_df['location_name'].isin([\"New York\", \"California\", \"United Kingdom\", \"France\", \"Poland\",  \"Italy\"]),\n",
    "                                  ['location_name', 'date', 'admis_mean', 'newICU_mean', 'deaths_mean', 'totdea_mean', 'allbed_mean', 'ICUbed_mean', 'bedover_mean', 'icuover_mean']].copy()\n",
    "\n",
    "# calculate days since the lockdown started\n",
    "def calulate_days_from_lockdown_start(df: pd.Series):\n",
    "    \"\"\"Return the number of days from the region's lockdown start to the row's date.\n",
    "\n",
    "    df: one row of the infections frame (the function is applied with axis=1);\n",
    "        reads its 'location_name' and 'date' ('YYYY-MM-DD') fields.\n",
    "\n",
    "    The result is negative before the lockdown, 0 on the start date and\n",
    "    positive afterwards, matching the 'Days since lockdown' plot axes and the\n",
    "    forward day count used by the model.\n",
    "    \"\"\"\n",
    "    lockdown_start = get_assumptions(df['location_name'])['start']\n",
    "    return (fromisoformat(df['date']) - lockdown_start).days\n",
    "\n",
    "infections_df['days_from_lockdown'] = infections_df.apply(calulate_days_from_lockdown_start, axis=1)\n",
    "\n",
    "# calculate cumulative admissions, accumulated separately for each location\n",
    "# (assumes rows are in chronological order within a location - TODO confirm)\n",
    "infections_df['admis_cum'] = infections_df.groupby('location_name')['admis_mean'].cumsum()\n",
    "\n",
    "# calculate doubling rate\n",
    "infections_df['doubling_rate_in_days'] = infections_df['admis_cum'] / infections_df['admis_mean']\n",
    "\n",
    "# same ratio, capped: we don't care about doubling rates longer than a year (at least)\n",
    "infections_df['progression'] = infections_df['doubling_rate_in_days'].clip(upper=365)\n",
    "\n",
    "# replace all NAs with 0 (fillna returns a new frame, so keep the result)\n",
    "infections_df = infections_df.fillna(0)\n",
    "infections_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Explore Features\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Hospital Admissions\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plotnine.options.figure_size = (8,4)\n",
    "(ggplot(infections_df, aes('days_from_lockdown', 'admis_mean', color='location_name'))\n",
    "     + geom_point()\n",
    "     + geom_vline(xintercept=0) \n",
    "     + labs(x='Days since lockdown',  y='New hospital admissions', title='Progression of hospital admissions')\n",
    "  )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Examine the evolution of the rate of the spread of the infection\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plotnine.options.figure_size = (12,4)\n",
    "(ggplot(infections_df, aes('days_from_lockdown', 'admis_mean', color='location_name'))\n",
    " + geom_point()\n",
    " + facet_wrap('~location_name', nrow=1)\n",
    " + labs(x='Days since lockdown',  y='New hospital admissions', title='Progression of hospital admissions')\n",
    " + guides(color=False)\n",
    " )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Bedover\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[covid all beds needed] - ([total bed capacity] - [average all bed\n",
    "usage])\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plotnine.options.figure_size = (10,4)\n",
    "(ggplot(infections_df, aes('days_from_lockdown', 'bedover_mean', color='location_name'))\n",
    " + geom_point()\n",
    " + facet_wrap('~location_name')\n",
    " + geom_vline(xintercept=0)\n",
    " + labs(x='Days since lockdown',  y='Mean Bedover', title='Projected Bed Utilization')\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mean Beds needed for Covid cases\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plotnine.options.figure_size = (10,4)\n",
    "# Labels describe allbed_mean; they were previously copied from the bedover plot.\n",
    "(ggplot(infections_df, aes('days_from_lockdown', 'allbed_mean', color='location_name'))\n",
    "   + geom_point()\n",
    "   + facet_wrap('~location_name')\n",
    "   + geom_vline(xintercept=0)\n",
    "   + labs(x='Days since lockdown',  y='Mean beds needed', title='Projected Beds Needed for Covid Cases')\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Infection Rate\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**TODO** Need to edit calculation where little or no data\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plotnine.options.figure_size = (10,4)\n",
    "(ggplot(infections_df, aes('days_from_lockdown', 'progression', color='location_name'))\n",
    "   + geom_point()\n",
    "   + facet_wrap('~location_name')\n",
    "   + geom_vline(xintercept=0)\n",
    "   + labs(x='Days since lockdown',  y='Pseudo doubling rate in days', title='Disease Growth' )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Approach**\n",
    "\n",
    "*Simple Model* The decision to end a lockdown is considered every day\n",
    "from the initial order onwards. The average hospital admissions across the\n",
    "previous five weeks are considered. As the number of Covid-related hospital\n",
    "admissions decreases, the likelihood of suspending the lockdown\n",
    "increases.\n",
    "\n",
    "*Conditions to consider adding*\n",
    "\n",
    "-   deaths (relative to population)\n",
    "-   deaths (relative to infected)\n",
    "-   ratio of recovery to newly infected\n",
    "-   % of population has been tested > threshold\n",
    "-   complex priors over dates (1st and 15th of the month more likely; perhaps\n",
    "    holidays too?)\n",
    "-   depletion of susceptible stock 1 - (Infected + Recovered + Deaths) /\n",
    "    population > threshold\n",
    "\n",
    "*Ad-Hoc High-level features to include*\n",
    "\n",
    "-   Government Market Orientation\n",
    "-   Government Goal &#x2014; Signaling Action | Spread Mitigation\n",
    "\n",
    "Model\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(area: Area, data: pd.core.frame.DataFrame):\n",
    "    \"\"\"Sample one lockdown duration (in days) for `area` and tag it with ergo.\n",
    "\n",
    "    area: assumptions dict for one region; reads its \"data_key\" and \"start\".\n",
    "    data: frame with 'location_name', 'days_from_lockdown' and 'admis_mean' columns.\n",
    "\n",
    "    Each simulated day, the mean of 'admis_mean' over the trailing 35 days is\n",
    "    turned into a stop probability and ergo.flip decides whether the lockdown\n",
    "    ends that day. The final day count is tagged under the area's data key.\n",
    "\n",
    "    Raises ValueError when the frame has no 'location_name' column, has no rows\n",
    "    for the area, the area has no lockdown start, or there is no data on or\n",
    "    after day 1 of the lockdown.\n",
    "    \"\"\"\n",
    "    data_key = area.get(\"data_key\")\n",
    "\n",
    "    # make sure reasonable things were passed in\n",
    "    if \"location_name\" not in data:\n",
    "        print(f'The data does not have the expected structure')\n",
    "        raise ValueError(\"bad data\")\n",
    "    elif data_key not in data[\"location_name\"].values:\n",
    "        print(f'There is currently no infection data for {data_key}')\n",
    "        raise ValueError(\"no data\")\n",
    "    elif not area.get(\"start\"):\n",
    "        print(f'There is currently no lockdown for {data_key}')\n",
    "        raise ValueError(\"no lockdown\")\n",
    "\n",
    "    # model start\n",
    "    area_data = data.loc[data[\"location_name\"] == data_key] # get data for area\n",
    "    last_day_with_data = area_data['days_from_lockdown'].max() # loop-invariant, so computed once\n",
    "    if not last_day_with_data >= 1:\n",
    "        # Without this check the first loop pass would use the spread rate before it is ever assigned.\n",
    "        print(f'There is no infection data after the lockdown start for {data_key}')\n",
    "        raise ValueError(\"no data after lockdown start\")\n",
    "\n",
    "    lockdown_duration = 0\n",
    "    while True:\n",
    "        lockdown_duration += 1\n",
    "        if last_day_with_data >= lockdown_duration: # keep using the last infection_spread_rate if we run out of data (this is a bad hack)\n",
    "            # take the average of the hospital admissions for the past five weeks\n",
    "            in_window = ((area_data['days_from_lockdown'] > lockdown_duration - 35) &\n",
    "                         (area_data['days_from_lockdown'] <= lockdown_duration))\n",
    "            infection_spread_rate = np.mean(area_data.loc[in_window, 'admis_mean'])\n",
    "\n",
    "        # The logistic density centred at 0 is highest at 0 (0.5 for scale=.5) and falls off as the\n",
    "        # admissions average grows; a smaller scale makes it fall off faster. So low admission\n",
    "        # averages give the highest chance of ending the lockdown on a given day.\n",
    "        stop_quarantine = ergo.flip(scipy.stats.logistic.pdf(infection_spread_rate, scale = .5))\n",
    "        if stop_quarantine:\n",
    "            break\n",
    "\n",
    "    ergo.tag(torch.Tensor([lockdown_duration]), data_key)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run Model\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "samples = pd.DataFrame() #unconditioned\n",
    "\n",
    "for question in questions:\n",
    " samples[question.name] = ergo.run(lambda: model(get_assumptions(question), infections_df), num_samples=1500).iloc[:,0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We calculated the model in terms of duration in days. Let's convert that to date format\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for question in questions:\n",
    "   if question.name in samples:\n",
    "     start_date = get_assumptions(question)[\"start\"]\n",
    "     samples[question.name] = samples[question.name].apply(lambda x: start_date + timedelta(days=x))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate Models\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plotnine.options.figure_size = (4,15)\n",
    "model_samples = pd.melt(samples, var_name=\"questions\", value_name=\"samples\")\n",
    "(ggplot(model_samples, aes(\"samples\"))\n",
    "                + geom_histogram()\n",
    "                + facet_wrap('questions', ncol=1)\n",
    "                + scale_x_datetime()\n",
    "                + labs(\n",
    "                    x=\"Prediction\",\n",
    "                    y=\"Counts\",\n",
    "                    title=f\"Samples from each lockdown model\"\n",
    "                )\n",
    "                + theme_bw() \n",
    "                + theme(axis_text_x=element_text(rotation=45, hjust=.5))\n",
    ") "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Submit Predictions\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the question by name: position 0 in `questions` is the United Kingdom question, not France.\n",
    "france_question = next(q for q in questions if q.name == \"France\")\n",
    "french_submission = france_question.get_submission_from_samples(samples['France'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "french_submission"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compare Potential Submission to the Community's Predictions\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for question in questions:\n",
    "    if question.name in samples:\n",
    "        print(question.show_submission(samples[question.name], show_community=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Submit from sample\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def submit_all():\n",
    "    \"\"\"Submit the sampled forecast for every question that has a samples column.\n",
    "\n",
    "    Questions without samples are reported and skipped. An HTTP error from a\n",
    "    submission is printed and does not stop the remaining submissions.\n",
    "    \"\"\"\n",
    "    for question in questions:\n",
    "        if question.name not in samples:\n",
    "            print(f\"No predictions for {question.name}\")\n",
    "            continue\n",
    "        try:\n",
    "            params = question.submit_from_samples(samples[question.name])\n",
    "            print(f\"Submitted Logistic{params} for {question.name}\")\n",
    "            print(f\"https://pandemic.metaculus.com{question.page_url}\")\n",
    "        except requests.exceptions.HTTPError as e:\n",
    "            print(f\"Couldn't make prediction for {question.name} -- maybe this question is now closed? See error below.\")\n",
    "            print(e)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Submit it!\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#submit_all()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
